{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Project Overview\n",
    "\n",
    "#### 1.1 Background\n",
    "The KDD Cup is the premier international competition in data mining research, organized by ACM SIGKDD, the ACM Special Interest Group on Knowledge Discovery and Data Mining. KDD stands for Knowledge Discovery and Data Mining. The competition has been held annually since 1997.\n",
    "The theme of KDD Cup 1999 was network intrusion detection. The competition provided a standard dataset of simulated intrusions into a military network, available at:\n",
    "\n",
    "+ [KDD Cup 1999 dataset](http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html)\n",
    "\n",
    "The full dataset is about 740 MB; to keep the experiment manageable, we use a 10% subset. Each connection record includes information such as the number of bytes sent, login counts, and TCP error counts. The dataset is in CSV format, one connection per line, with 42 columns.\n",
    "\n",
    "#### 1.2 Network Intrusion\n",
    "Counting how many times each port is accessed remotely within a short time window yields a feature that predicts port-scan attacks well. Detecting network intrusion means finding connections that differ from those seen before. K-means can cluster network connections by their statistical attributes; the resulting clusters describe the historical connection types and delineate the region of normal connections. Any point outside that region is considered anomalous.\n",
    "\n",
    "#### 1.3 Anomaly Detection\n",
    "Anomaly detection is commonly used to catch fraud, network attacks, and failures in servers and sensor devices. In these applications we must be able to find novel anomalies never seen before, such as new fraud schemes, new intrusion methods, or new server failure modes."
   ]
  },
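  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind Sections 1.2 and 1.3 can be sketched in a few lines: given cluster centroids that define the \"normal\" regions, a point is flagged as anomalous when its distance to the nearest centroid exceeds a threshold. The cell below is a minimal, self-contained illustration with made-up centroids and a made-up threshold; it is not part of the experiment that follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def is_anomaly(point, centroids, threshold):\n",
    "    \"\"\"Flag a point as anomalous if it lies farther than `threshold`\n",
    "    from every centroid, i.e. outside all of the normal regions.\"\"\"\n",
    "    dists = [np.linalg.norm(point - c) for c in centroids]\n",
    "    return min(dists) > threshold\n",
    "\n",
    "# Hypothetical 2-D centroids and threshold, for illustration only\n",
    "centroids = [np.array([0.0, 0.0]), np.array([5.0, 5.0])]\n",
    "print(is_anomaly(np.array([0.1, 0.2]), centroids, threshold=1.0))   # False: close to a centroid\n",
    "print(is_anomaly(np.array([9.0, -4.0]), centroids, threshold=1.0))  # True: far from both\n"
   ]
  },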
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Data Preparation\n",
    "\n",
    "#### 2.1 Load the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.context import SparkContext\n",
    "from pyspark.sql.session import SparkSession"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "sc = SparkContext(\"local[*]\",\"KDD Cup 99\")\n",
    "spark = SparkSession(sc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_file = \"data/Kmeans/kddcup.data_10_percent_corrected\"\n",
    "kddCup = sc.textFile(data_file)\n",
    "kddCup = kddCup.map(lambda row : row.split(\",\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "View the first row of data:\n",
      "[['0', 'tcp', 'http', 'SF', '181', '5450', '0', '0', '0', '0', '0', '1', '0', '0', '0', '0', '0', '0', '0', '0', '0', '0', '8', '8', '0.00', '0.00', '0.00', '0.00', '1.00', '0.00', '0.00', '9', '9', '1.00', '0.00', '0.11', '0.00', '0.00', '0.00', '0.00', '0.00', 'normal.']]\n",
       "Number of columns: 42\n"
     ]
    }
   ],
   "source": [
    "print(\"View the first row of data:\")\n",
    "print(kddCup.take(1))\n",
    "print(\"Number of columns:\", len(kddCup.take(1)[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['duration',\n",
       " 'protocol_type',\n",
       " 'service',\n",
       " 'flag',\n",
       " 'src_bytes',\n",
       " 'dst_bytes',\n",
       " 'land',\n",
       " 'wrong_fragment',\n",
       " 'urgent',\n",
       " 'hot',\n",
       " 'num_failed_logins',\n",
       " 'logged_in',\n",
       " 'num_compromise',\n",
       " 'root_shell',\n",
       " 'su_attempted',\n",
       " 'num_root',\n",
       " 'num_file_creations',\n",
       " 'num_shells',\n",
       " 'num_access_files',\n",
       " 'num_outbound_cmds',\n",
       " 'is_host_login',\n",
       " 'is_guest_login',\n",
       " 'count',\n",
       " 'srv_count',\n",
       " 'serror_rate',\n",
       " 'srv_serror_rate',\n",
       " 'rerror_rate',\n",
       " 'srv_rerror_rate',\n",
       " 'same_srv_rate',\n",
       " 'diff_srv_rate',\n",
       " 'srv_diff_host_rate',\n",
       " 'dst_host_count',\n",
       " 'dst_host_srv_count',\n",
       " 'dst_host_same_srv_rate',\n",
       " 'dst_host_diff_srv_rate',\n",
       " 'dst_host_same_src_port_rate',\n",
       " 'dst_host_srv_diff_host_rate',\n",
       " 'dst_host_serror_rate',\n",
       " 'dst_host_srv_serror_rate',\n",
       " 'dst_host_rerror_rate',\n",
       " 'dst_host_srv_rerror_rate',\n",
       " 'label']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the dataset's column names\n",
    "header = open(\"data/Kmeans/header.txt\").read().split(\",\")\n",
    "header"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.2 Dataset Description\n",
    "The output above shows that the dataset has 42 columns. The table below gives each field's description and data type.\n",
    "<table border=\"\" width=\"80%\" nosave=\"\">\n",
    "<tbody><tr nosave=\"\">\n",
    "<td><i>feature name</i></td>\n",
    "\n",
    "<td nosave=\"\"><i>description&nbsp;</i></td>\n",
    "\n",
    "<td><i>type</i></td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>duration&nbsp;</td>\n",
    "\n",
    "<td>length (number of seconds) of the connection&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>protocol_type&nbsp;</td>\n",
    "\n",
    "<td>type of the protocol, e.g. tcp, udp, etc.&nbsp;</td>\n",
    "\n",
    "<td>discrete</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>service&nbsp;</td>\n",
    "\n",
    "<td>network service on the destination, e.g., http, telnet, etc.&nbsp;</td>\n",
    "\n",
    "<td>discrete</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>src_bytes&nbsp;</td>\n",
    "\n",
    "<td>number of data bytes from source to destination&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>dst_bytes&nbsp;</td>\n",
    "\n",
    "<td>number of data bytes from destination to source&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>flag&nbsp;</td>\n",
    "\n",
    "<td>normal or error status of the connection&nbsp;</td>\n",
    "\n",
    "<td>discrete&nbsp;</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>land&nbsp;</td>\n",
    "\n",
    "<td>1 if connection is from/to the same host/port; 0 otherwise&nbsp;</td>\n",
    "\n",
    "<td>discrete</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>wrong_fragment&nbsp;</td>\n",
    "\n",
    "<td>number of ``wrong'' fragments&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>urgent&nbsp;</td>\n",
    "\n",
    "<td>number of urgent packets&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td>hot&nbsp;</td>\n",
    "\n",
    "<td>number of ``hot'' indicators</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>num_failed_logins&nbsp;</td>\n",
    "\n",
    "<td>number of failed login attempts&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>logged_in&nbsp;</td>\n",
    "\n",
    "<td>1 if successfully logged in; 0 otherwise&nbsp;</td>\n",
    "\n",
    "<td>discrete</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>num_compromised&nbsp;</td>\n",
    "\n",
    "<td>number of ``compromised'' conditions&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>root_shell&nbsp;</td>\n",
    "\n",
    "<td>1 if root shell is obtained; 0 otherwise&nbsp;</td>\n",
    "\n",
    "<td>discrete</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>su_attempted&nbsp;</td>\n",
    "\n",
    "<td>1 if ``su root'' command attempted; 0 otherwise&nbsp;</td>\n",
    "\n",
    "<td>discrete</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>num_root&nbsp;</td>\n",
    "\n",
    "<td>number of ``root'' accesses&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>num_file_creations&nbsp;</td>\n",
    "\n",
    "<td>number of file creation operations&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>num_shells&nbsp;</td>\n",
    "\n",
    "<td>number of shell prompts&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>num_access_files&nbsp;</td>\n",
    "\n",
    "<td>number of operations on access control files&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr nosave=\"\">\n",
    "<td>num_outbound_cmds</td>\n",
    "\n",
    "<td nosave=\"\">number of outbound commands in an ftp session&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>is_hot_login&nbsp;</td>\n",
    "\n",
    "<td>1 if the login belongs to the ``hot'' list; 0 otherwise&nbsp;</td>\n",
    "\n",
    "<td>discrete</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>is_guest_login&nbsp;</td>\n",
    "\n",
    "<td>1 if the login is a ``guest''login; 0 otherwise&nbsp;</td>\n",
    "\n",
    "<td>discrete</td>\n",
    "</tr>\n",
    "<tr>\n",
    "<td>count&nbsp;</td>\n",
    "\n",
    "<td>number of connections to the same host as the current connection in\n",
    "the past two seconds&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>serror_rate&nbsp;</td>\n",
    "\n",
    "<td>% of connections that have ``SYN'' errors&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>rerror_rate&nbsp;</td>\n",
    "\n",
    "<td>% of connections that have ``REJ'' errors&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>same_srv_rate&nbsp;</td>\n",
    "\n",
    "<td>% of connections to the same service&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>diff_srv_rate&nbsp;</td>\n",
    "\n",
    "<td>% of connections to different services&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>srv_count&nbsp;</td>\n",
    "\n",
    "<td>number of connections to the same service as the current connection\n",
    "in the past two seconds&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>srv_serror_rate&nbsp;</td>\n",
    "\n",
    "<td>% of connections that have ``SYN'' errors&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>srv_rerror_rate&nbsp;</td>\n",
    "\n",
    "<td>% of connections that have ``REJ'' errors&nbsp;</td>\n",
    "\n",
    "<td>continuous</td>\n",
    "</tr>\n",
    "\n",
    "<tr>\n",
    "<td>srv_diff_host_rate&nbsp;</td>\n",
    "\n",
    "<td>% of connections to different hosts&nbsp;</td>\n",
    "\n",
    "<td>continuous&nbsp;</td>\n",
    "</tr>\n",
    "    \n",
    "<tr>\n",
    "<td>label&nbsp;</td>\n",
    "\n",
    "<td>connection type&nbsp;</td>\n",
    "\n",
    "<td>discrete&nbsp;</td>\n",
    "</tr>\n",
    "</tbody>\n",
    "</table>\n",
    "\n",
    "\n",
    "In the table above, \"feature name\" is the field name, \"description\" describes the field, and \"type\" is the field type.\n",
    "\n",
    "Field types come in two kinds:\n",
    "+ continuous: continuous features\n",
    "+ discrete: discrete features\n",
    "\n",
    "The last column, label, is the target variable: the type of the network connection. There are 23 connection types in total: normal plus the following 22 attack types, each listed with its broad attack category (dos, u2r, r2l, or probe):\n",
    "+ back dos\n",
    "+ buffer_overflow u2r\n",
    "+ ftp_write r2l\n",
    "+ guess_passwd r2l\n",
    "+ imap r2l\n",
    "+ ipsweep probe\n",
    "+ land dos\n",
    "+ loadmodule u2r\n",
    "+ multihop r2l\n",
    "+ neptune dos\n",
    "+ nmap probe\n",
    "+ perl u2r\n",
    "+ phf r2l\n",
    "+ pod dos\n",
    "+ portsweep probe\n",
    "+ rootkit u2r\n",
    "+ satan probe\n",
    "+ smurf dos\n",
    "+ spy r2l\n",
    "+ teardrop dos\n",
    "+ warezclient r2l\n",
    "+ warezmaster r2l"
   ]
  },
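  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 22 attack labels above fall into four broad categories: dos, probe, r2l, and u2r. When interpreting clusters it can be handy to collapse the fine-grained labels into these categories. The helper below is a small illustrative sketch built from the list above; it is not used in the experiment that follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Broad category for each attack label, taken from the list above;\n",
    "# labels in the data carry a trailing dot, e.g. 'smurf.'\n",
    "ATTACK_CATEGORY = {\n",
    "    'back': 'dos', 'land': 'dos', 'neptune': 'dos', 'pod': 'dos',\n",
    "    'smurf': 'dos', 'teardrop': 'dos',\n",
    "    'ipsweep': 'probe', 'nmap': 'probe', 'portsweep': 'probe', 'satan': 'probe',\n",
    "    'ftp_write': 'r2l', 'guess_passwd': 'r2l', 'imap': 'r2l', 'multihop': 'r2l',\n",
    "    'phf': 'r2l', 'spy': 'r2l', 'warezclient': 'r2l', 'warezmaster': 'r2l',\n",
    "    'buffer_overflow': 'u2r', 'loadmodule': 'u2r', 'perl': 'u2r', 'rootkit': 'u2r',\n",
    "}\n",
    "\n",
    "def broad_category(label):\n",
    "    \"\"\"Map a raw label such as 'smurf.' to its broad category; 'normal.' maps to 'normal'.\"\"\"\n",
    "    return ATTACK_CATEGORY.get(label.rstrip('.'), 'normal')\n",
    "\n",
    "print(broad_category('smurf.'))   # dos\n",
    "print(broad_category('normal.'))  # normal\n"
   ]
  },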
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2.3 Data Exploration\n",
    "In the code below we will:\n",
    "+ convert the dataset into a DataFrame\n",
    "+ count the number of samples in each class of the target variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+-------------+-------+----+---------+---------+----+--------------+------+---+-----------------+---------+--------------+----------+------------+--------+------------------+----------+----------------+-----------------+-------------+--------------+-----+---------+-----------+---------------+-----------+---------------+-------------+-------------+------------------+--------------+------------------+----------------------+----------------------+---------------------------+---------------------------+--------------------+------------------------+--------------------+------------------------+-------+\n",
      "|duration|protocol_type|service|flag|src_bytes|dst_bytes|land|wrong_fragment|urgent|hot|num_failed_logins|logged_in|num_compromise|root_shell|su_attempted|num_root|num_file_creations|num_shells|num_access_files|num_outbound_cmds|is_host_login|is_guest_login|count|srv_count|serror_rate|srv_serror_rate|rerror_rate|srv_rerror_rate|same_srv_rate|diff_srv_rate|srv_diff_host_rate|dst_host_count|dst_host_srv_count|dst_host_same_srv_rate|dst_host_diff_srv_rate|dst_host_same_src_port_rate|dst_host_srv_diff_host_rate|dst_host_serror_rate|dst_host_srv_serror_rate|dst_host_rerror_rate|dst_host_srv_rerror_rate|  label|\n",
      "+--------+-------------+-------+----+---------+---------+----+--------------+------+---+-----------------+---------+--------------+----------+------------+--------+------------------+----------+----------------+-----------------+-------------+--------------+-----+---------+-----------+---------------+-----------+---------------+-------------+-------------+------------------+--------------+------------------+----------------------+----------------------+---------------------------+---------------------------+--------------------+------------------------+--------------------+------------------------+-------+\n",
      "|       0|          tcp|   http|  SF|      181|     5450|   0|             0|     0|  0|                0|        1|             0|         0|           0|       0|                 0|         0|               0|                0|            0|             0|    8|        8|       0.00|           0.00|       0.00|           0.00|         1.00|         0.00|              0.00|             9|                 9|                  1.00|                  0.00|                       0.11|                       0.00|                0.00|                    0.00|                0.00|                    0.00|normal.|\n",
      "|       0|          tcp|   http|  SF|      239|      486|   0|             0|     0|  0|                0|        1|             0|         0|           0|       0|                 0|         0|               0|                0|            0|             0|    8|        8|       0.00|           0.00|       0.00|           0.00|         1.00|         0.00|              0.00|            19|                19|                  1.00|                  0.00|                       0.05|                       0.00|                0.00|                    0.00|                0.00|                    0.00|normal.|\n",
      "+--------+-------------+-------+----+---------+---------+----+--------------+------+---+-----------------+---------+--------------+----------+------------+--------+------------------+----------+----------------+-----------------+-------------+--------------+-----+---------+-----------+---------------+-----------+---------------+-------------+-------------+------------------+--------------+------------------+----------------------+----------------------+---------------------------+---------------------------+--------------------+------------------------+--------------------+------------------------+-------+\n",
      "only showing top 2 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Wrap the dataset in a DataFrame\n",
    "kddCupDataFrame = spark.createDataFrame(kddCup,header)\n",
    "kddCupDataFrame.show(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As the output above shows, the dataset has too many columns to inspect comfortably, so below we select just a few of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------+-------------+-------+----+-------+\n",
      "|duration|protocol_type|service|flag|  label|\n",
      "+--------+-------------+-------+----+-------+\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "|       0|          tcp|   http|  SF|normal.|\n",
      "+--------+-------------+-------+----+-------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "kddCupDataFrame.select(\"duration\",\"protocol_type\",\"service\",\"flag\",\"label\").show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Count the number of samples in each class of the label column**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "smurf. 280790\n",
      "neptune. 107201\n",
      "normal. 97278\n",
      "back. 2203\n",
      "satan. 1589\n",
      "ipsweep. 1247\n",
      "portsweep. 1040\n",
      "warezclient. 1020\n",
      "teardrop. 979\n",
      "pod. 264\n",
      "nmap. 231\n",
      "guess_passwd. 53\n",
      "buffer_overflow. 30\n",
      "land. 21\n",
      "warezmaster. 20\n",
      "imap. 12\n",
      "rootkit. 10\n",
      "loadmodule. 9\n",
      "ftp_write. 8\n",
      "multihop. 7\n",
      "phf. 4\n",
      "perl. 3\n",
      "spy. 2\n"
     ]
    }
   ],
   "source": [
    "from collections import OrderedDict\n",
    "# There are 23 classes in total, so 23 lines are printed\n",
    "labels = kddCup.map(lambda line: line[-1])\n",
    "label_counts = labels.countByValue()\n",
    "sorted_labels = OrderedDict(sorted(label_counts.items(), key=lambda t: t[1], reverse=True))\n",
    "for label, count in sorted_labels.items():\n",
    "    print(label, count)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Data Preprocessing\n",
    "This step performs the following operations:\n",
    "+ drop the non-numeric features (K-means requires all input features to be numeric)\n",
    "+ convert the numeric features to floats\n",
    "+ standardize the numeric features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the dataset to make preprocessing easier\n",
    "raw_data = sc.textFile(data_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.1 Drop Non-Numeric Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy import array\n",
    "from math import sqrt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The dataset contains non-numeric features, but the clustering algorithm requires numeric input.\n",
    "# For simplicity, the function below simply drops them: columns 1-3\n",
    "# (protocol_type, service, flag) are removed, and the label is kept separately.\n",
    "def parse_interaction(line):\n",
    "    \"\"\"\n",
    "    Parse one network connection record into (label, numeric feature vector).\n",
    "    \"\"\"\n",
    "    line_split = line.split(\",\")\n",
    "    clean_line_split = [line_split[0]]+line_split[4:-1]\n",
    "    return (line_split[-1], array([float(x) for x in clean_line_split]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the dataset into clustering input features\n",
    "parsed_data = raw_data.map(parse_interaction)\n",
    "# Cache the feature vectors in memory to speed up processing\n",
    "parsed_data_values = parsed_data.values().cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3.2 Standardize the Numeric Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.mllib.feature import StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# withMean=True, withStd=True: center each feature to zero mean and scale to unit variance\n",
    "standardizer = StandardScaler(True, True)\n",
    "standardizer_model = standardizer.fit(parsed_data_values)\n",
    "standardized_data_values = standardizer_model.transform(parsed_data_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[DenseVector([-0.0678, -0.0029, 0.1387, -0.0067, -0.0477, -0.0026, -0.0441, -0.0098, 2.397, -0.0057, -0.0106, -0.0047, -0.0056, -0.0112, -0.0099, -0.0276, 0.0, 0.0, -0.0373, -1.5214, -1.1566, -0.4641, -0.4635, -0.248, -0.2486, 0.537, -0.2552, -0.2036, -3.4515, -1.6943, 0.5994, -0.2829, -1.0221, -0.1586, -0.4644, -0.4632, -0.252, -0.2495])]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "standardized_data_values.take(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output above shows that all numeric features have been standardized."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Model Training and Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.mllib.clustering import KMeans"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4.1 Euclidean Distance from a Sample to Its Cluster Centroid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dist_to_centroid(datum, clusters):\n",
    "    \"\"\"\n",
    "    Compute the Euclidean distance from a sample to its assigned cluster centroid.\n",
    "    Args:\n",
    "        datum: a data sample\n",
    "        clusters: the trained clustering model\n",
    "    Returns:\n",
    "        the Euclidean distance from the sample to its centroid\n",
    "    \"\"\"\n",
    "    cluster = clusters.predict(datum)\n",
    "    centroid = clusters.centers[cluster]\n",
    "    return sqrt(sum([x**2 for x in (centroid - datum)]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4.2 Run the Clustering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def clustering_score(data, k):\n",
    "    \"\"\"\n",
    "    Cluster the given samples with k centroids and score the result.\n",
    "    Args:\n",
    "        data: the samples to cluster\n",
    "        k: the number of cluster centroids\n",
    "    Returns:\n",
    "        result: a tuple of k, the trained model, and the mean distance from each point to its centroid\n",
    "    \"\"\"\n",
    "    clusters = KMeans.train(data, k, maxIterations=10, runs=5, initializationMode=\"random\")\n",
    "    result = (k, clusters, data.map(lambda datum: dist_to_centroid(datum, clusters)).mean())\n",
    "    print(\"Clustering score for k=%(k)d is %(score)f\" % {\"k\": k, \"score\": result[2]})\n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4.3 Find the Best K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "Calculating the mean within-cluster distance for different k values (10 to 25):\n"
     ]
    }
   ],
   "source": [
    "# Section 2.3 showed that the dataset has 23 label classes, so we search for the best K between 10 and 25\n",
    "max_k = 25\n",
    "print(\"Calculating the mean within-cluster distance for different k values (10 to %(max_k)d):\" \\\n",
    "      % {\"max_k\": max_k})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/spark/python/pyspark/mllib/clustering.py:347: UserWarning: The param `runs` has no effect since Spark 2.0.0.\n",
      "  warnings.warn(\"The param `runs` has no effect since Spark 2.0.0.\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Clustering score for k=10 is 1.082973\n",
      "Clustering score for k=11 is 0.980190\n",
      "Clustering score for k=12 is 0.883681\n",
      "Clustering score for k=13 is 1.026729\n",
      "Clustering score for k=14 is 0.973560\n",
      "Clustering score for k=15 is 1.015176\n",
      "Clustering score for k=16 is 0.905715\n",
      "Clustering score for k=17 is 1.027638\n",
      "Clustering score for k=18 is 0.841901\n",
      "Clustering score for k=19 is 0.911386\n",
      "Clustering score for k=20 is 0.874412\n",
      "Clustering score for k=21 is 0.817983\n",
      "Clustering score for k=22 is 1.050033\n",
      "Clustering score for k=23 is 0.765774\n",
      "Clustering score for k=24 is 0.690021\n",
      "Clustering score for k=25 is 0.705851\n"
     ]
    }
   ],
   "source": [
    "# Train a model for each K value to find the best one\n",
    "scores = []\n",
    "for k in range(10,max_k+1):\n",
    "    result = clustering_score(standardized_data_values, k)\n",
    "    scores.append(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4.4 Pick the K with the Best Score\n",
    "Note that the mean distance tends to decrease as K grows, so simply taking the minimum score favors larger K; in practice an elbow criterion is often preferred."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "The best K value is 24\n"
     ]
    }
   ],
   "source": [
    "# Obtain the k with the minimum score\n",
    "min_k = min(scores, key=lambda x: x[2])[0]\n",
    "print(\"The best K value is %(best_k)d\" % {\"best_k\": min_k})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
       "Saving data......\n",
       "Data saved to: data/Kmeans/sample_standardized\n"
     ]
    }
   ],
   "source": [
    "# Use the clustering model with the best K to assign a cluster to every sample, then keep a 5% sample\n",
    "best_model = min(scores, key=lambda x: x[2])[1]\n",
    "cluster_assignments_sample = standardized_data_values.map(\n",
    "    lambda datum: str(best_model.predict(datum))+\",\"+\",\".join(map(str,datum))).sample(False,0.05)\n",
    "\n",
    "# Save the assignment sample to file\n",
    "print(\"Saving data......\")\n",
    "save_path = \"data/Kmeans/sample_standardized\"\n",
    "cluster_assignments_sample.saveAsTextFile(save_path)\n",
    "print(\"Data saved to: %s\" % save_path)"
   ]
  },
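  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Following Section 1.2, the best model can be turned into an anomaly detector: compute each point's distance to its assigned centroid with `dist_to_centroid` from Section 4.1, pick a threshold such as the 100th-largest distance, and flag every point beyond it. The thresholding step itself is sketched below on made-up distances; on the real data one might build the distances with `standardized_data_values.map(lambda d: dist_to_centroid(d, best_model))` and take the threshold from `distances.top(100)[-1]`. This is an illustrative sketch, not part of the experiment above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import heapq\n",
    "\n",
    "def anomaly_threshold(distances, n_outliers):\n",
    "    \"\"\"Return the n-th largest distance, so that roughly n points reach or exceed it.\"\"\"\n",
    "    return heapq.nlargest(n_outliers, distances)[-1]\n",
    "\n",
    "# Made-up centroid distances, for illustration only\n",
    "distances = [0.2, 0.5, 0.1, 3.0, 0.4, 7.5, 0.3]\n",
    "t = anomaly_threshold(distances, n_outliers=2)\n",
    "print(t)                                 # 3.0\n",
    "print([d for d in distances if d >= t])  # [3.0, 7.5]\n"
   ]
  },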
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
